import tensorflow as tf

from tfFunctionsUtils import get_joint_distributions_from_samples


# def get_generated_labels(Exp, label_generators, intervened, chosen_labels, batch_size, pre_trained=[]):
def get_generated_labels(Exp, label_generators):
    """Sample fresh noise and run the generators to build one batch of fake outputs.

    Args:
        Exp: Experiment object. This function reads ``label_names``,
            ``image_labels``, ``exogenous``, ``latent_conf``, ``batch_size``,
            ``NOISE_DIM`` and ``ENCODED_DIM`` from it.
        label_generators: Mapping with callable entries ``'covid_19'``,
            ``'xray'`` and ``'Rxray'``, and an indexable ``'pneum'`` entry
            whose elements 0 and 1 are both called.

    Returns:
        dict with the keys ``'covid_19'``, ``'pneum'`` and ``'Rxray'``, inserted
        in that order. The intermediate xray output is used internally and is
        not part of the returned dict.

    Note:
        The noise keys ``'ncovid_19'``, ``'npneum'`` and ``'U0'`` are
        hard-coded, so ``Exp.exogenous`` / ``Exp.latent_conf`` must produce
        them or a ``KeyError`` is raised.
    """

    def _gaussian_noise():
        # One standard-normal block of shape [batch_size, NOISE_DIM].
        return tf.random.normal([Exp.batch_size, Exp.NOISE_DIM], mean=0.0, stddev=1.0,
                                dtype=tf.dtypes.float32)

    # Exogenous noise for every label that is not an image label.
    exo_noise = {}
    for label in Exp.label_names:
        if label in Exp.image_labels:
            continue
        exo_noise[Exp.exogenous[label]] = _gaussian_noise()

    # One noise block per latent confounder; a confounder listed under several
    # labels is re-sampled each time and the last draw is the one kept.
    conf_noise = {}
    for label in Exp.label_names:
        for confounder in Exp.latent_conf[label]:
            conf_noise[confounder] = _gaussian_noise()

    generated = {}

    # ---- covid_19: generator fed with its own noise plus confounder U0.
    covid_input = tf.concat(axis=1, values=[exo_noise['ncovid_19'], conf_noise['U0']])
    generated['covid_19'] = label_generators['covid_19']([covid_input], training=True)

    # ---- xray: image generator conditioned on a rescaled covid_19 column.
    image_noise = tf.random.uniform([Exp.batch_size, Exp.ENCODED_DIM], minval=-1, maxval= 2)
    # Integer draw in {0, 1} mapped to {-3, +3}.
    sign = tf.random.uniform(shape=(Exp.batch_size, 1), minval=0, maxval=2, dtype=tf.int32) * 6 - 3
    sign = tf.cast(sign, tf.float32)
    jitter = tf.random.normal(shape=(Exp.batch_size, 1), mean=0.0, stddev=0.01, dtype=tf.float32)
    # Per-sample multiplier: 1.5 plus (+/-3 times a small Gaussian jitter).
    scale = 1.5 + tf.math.multiply(sign, jitter)

    covid_first_col = tf.reshape(generated['covid_19'][:, 0], [-1, 1])
    image_parent = tf.math.multiply(covid_first_col, scale)

    raw_images = label_generators['xray']([image_noise, image_parent])
    xray = tf.keras.layers.Normalization(mean=0., variance=1.)(raw_images)

    # ---- pneum: output of model 0 on the xray, concatenated with noise, then
    # passed through model 1 (the only one of the two called with training=True).
    image_model = label_generators['pneum'][0]
    trainable_head = label_generators['pneum'][1]

    pneum_noise = tf.concat(axis=1, values=[exo_noise['npneum'], conf_noise['U0']])
    # Original author's note: model 0 is a pre-trained model giving the image label.
    image_prediction = image_model([xray])
    pneum_input = tf.concat([image_prediction, pneum_noise], 1)
    generated['pneum'] = trainable_head(pneum_input, training=True)

    # ---- Rxray: the 'Rxray' model applied to the generated xray.
    generated['Rxray'] = label_generators['Rxray']([xray])

    return generated


from sklearn.preprocessing import OneHotEncoder

def calculate_joint(keep_G_fake):
    """Estimate the joint distribution of hard covid_19 / pneum labels.

    Args:
        keep_G_fake: 2-D array or tensor supporting ``[:, a:b]`` slicing.
            Columns 0:2 are read as the covid_19 scores and columns 2:4 as
            the pneum scores.

    Returns:
        Whatever ``get_joint_distributions_from_samples`` returns for the
        variables ``['covid_19', 'pneum']`` with cardinalities ``[2, 2]``.

    Note:
        ``.numpy()`` is called on the joint sample tensor, so this presumably
        has to run in eager mode. Shapes and the result are printed as
        debugging output.
    """

    def _hard_label(scores):
        # Argmax over the score columns, returned as an int32 column vector.
        print(scores.shape)
        winner = tf.reshape(tf.math.argmax(scores, axis=1), [-1, 1])
        print(winner.shape)
        return tf.cast(winner, tf.int32)

    covid_idx = _hard_label(keep_G_fake[:, 0:2])
    pneum_idx = _hard_label(keep_G_fake[:, 2:4])

    samples = tf.concat(axis=1, values=[covid_idx, pneum_idx])
    prob = get_joint_distributions_from_samples(['covid_19', 'pneum'], [2, 2], samples.numpy())

    print('prob-->',prob)
    return prob

def train_D(Exp, cur_mechs, label_generators, label_discriminator, D_optimizer, data_batch, image_batch):
    """Run one discriminator (critic) update and return its loss.

    Args:
        Exp: Experiment object; ``Exp.LAMBDA_GP`` weights the penalty term and
            the object is forwarded to ``get_generated_labels``.
        cur_mechs: Not used in this function.
        label_generators: Generator mapping; ``'Rxray'`` encodes the real
            images and the whole mapping produces the fake batch.
        label_discriminator: Model whose trainable variables are updated.
        D_optimizer: Optimizer applied to the discriminator gradients.
        data_batch: Real label batch; it is one-hot encoded here.
        image_batch: Real images passed to ``label_generators['Rxray']``.

    Returns:
        The scalar discriminator loss
        ``mean(D_fake - D_real + LAMBDA_GP * penalty)``.
    """
    print('--->',data_batch[0])
    # NOTE(review): the encoder is re-fitted on every batch, so the encoded
    # width depends on which values occur in this batch. calculate_joint
    # later slices columns 0:2 and 2:4 -- confirm every batch contains both
    # values of both labels.
    data_batch = OneHotEncoder().fit(data_batch).transform(data_batch).toarray()
    print('--->',data_batch[0])

    # Real discriminator input: one-hot labels followed by the encoded image.
    real_image_code = label_generators['Rxray']([image_batch])
    real_input = tf.concat([data_batch, real_image_code], 1)

    # Fake batch is produced outside the tape; only discriminator variables
    # are differentiated below.
    fake_parts = get_generated_labels(Exp, label_generators)

    with tf.GradientTape() as disc_tape:
        fake_input = tf.concat(axis=1, values=list(fake_parts.values()))
        D_real = label_discriminator([real_input], training=True)
        D_fake = label_discriminator([fake_input], training=True)
        # NOTE(review): penalty_calculation is neither defined nor imported in
        # the visible part of this file -- confirm it is provided elsewhere.
        penalty = penalty_calculation(label_discriminator, real_input, fake_input)
        D_loss = tf.reduce_mean(D_fake - D_real + Exp.LAMBDA_GP * penalty)

        print('-->',D_loss)

    print('disc_tape --->', disc_tape)
    critic_vars = label_discriminator.trainable_variables
    critic_grads = disc_tape.gradient(D_loss, critic_vars)
    D_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Empirical joint of the real (one-hot encoded) labels, for logging.
    prob = calculate_joint(data_batch)

    print('Real prob:',prob)

    return D_loss

def train_G(Exp, cur_mechs, label_generators, G_optimizers, label_discriminator, data_batch):
    """Run one generator update and return the generator loss.

    Only ``label_generators['covid_19']`` and ``label_generators['pneum'][1]``
    are updated here; no other generator entry receives gradients in this
    function.

    Args:
        Exp: Experiment object forwarded to ``get_generated_labels``.
        cur_mechs: Not used in this function.
        label_generators: Generator mapping (see ``get_generated_labels``).
        G_optimizers: Mapping with optimizers under ``'covid_19'`` and
            ``'pneum'``.
        label_discriminator: Critic used to score the generated batch.
        data_batch: Not used in this function.

    Returns:
        The scalar generator loss ``-mean(D_fake)``.
    """
    with tf.GradientTape() as gen_tape:
        # Keep the per-label dict for the joint-distribution log below. A
        # reference is enough: the dict is never mutated here, so no copy
        # (and no `copy` module, which this file does not import) is needed.
        fake_parts = get_generated_labels(Exp, label_generators)
        G_fake = tf.concat(axis=1, values=list(fake_parts.values()))

        D_fake = label_discriminator([G_fake], training=True)
        G_loss = -tf.reduce_mean(D_fake)

        print('G_loss--->', G_loss)

    print('before success')
    # Variables are looked up after the forward pass, as in the original
    # statement order, so lazily built models have created them by now.
    covid_vars = label_generators['covid_19'].trainable_variables
    pneum_vars = label_generators['pneum'][1].trainable_variables
    grad1, grad2 = gen_tape.gradient(G_loss, [covid_vars, pneum_vars])
    G_optimizers['covid_19'].apply_gradients(zip(grad1, covid_vars))
    G_optimizers['pneum'].apply_gradients(zip(grad2, pneum_vars))
    print('After success')

    # Joint distribution of the generated labels, without the 'Rxray' entry.
    label_parts = [value for name, value in fake_parts.items() if name != 'Rxray']
    keep_G_fake = tf.concat(axis=1, values=label_parts)
    prob = calculate_joint(keep_G_fake)

    print('Fake prob:',prob)

    return G_loss




def trainloop(Exp, cur_hnodes, label_generators, G_optimizers, discriminators, D_optimizers, train_dataset, tvd_diff, kl_diff):
    """Run one pass over ``train_dataset``, alternating generator and critic updates.

    For every (image batch, label batch) pair and every entry of
    ``cur_hnodes``, the generator is trained once and the matching
    discriminator ``Exp.CRITIC_ITERATIONS`` times. After each batch
    ``Exp.anneal_temperature`` is called whenever the running iteration
    count is a multiple of 100.

    Args:
        Exp: Experiment object; reads ``CRITIC_ITERATIONS``, ``curr_epoochs``
            (attribute name as spelled on the object), ``num_epochs`` and
            calls ``anneal_temperature``.
        cur_hnodes: Mapping from h-node key to its current mechanisms.
        label_generators: Generator mapping passed to ``train_G``/``train_D``.
        G_optimizers: Generator optimizers passed to ``train_G``.
        discriminators: Mapping from h-node key to discriminator.
        D_optimizers: Mapping from h-node key to discriminator optimizer.
        train_dataset: Indexable by ``'img'`` and ``'labels'``; both entries
            are iterated in lockstep. Each label batch is indexed by
            ``'covid_19'`` and ``'pneumonia'``.
        tvd_diff: Not used in the visible body.
        kl_diff: Not used in the visible body.

    Returns:
        None.
    """
    iteration = 0
    for img_batch, label_batch in zip(train_dataset['img'], train_dataset['labels']):
        # The label matrix does not depend on the h-node, so build it once per
        # batch. This also keeps `udata_batch` defined for the final print
        # when `cur_hnodes` is empty.
        udata_batch = tf.concat(axis=1, values=[
            tf.reshape(label_batch['covid_19'], [-1, 1]),
            tf.reshape(label_batch['pneumonia'], [-1, 1]),
        ])

        for hn, cur_mechs in cur_hnodes.items():
            G_loss = train_G(Exp, cur_mechs, label_generators, G_optimizers, discriminators[hn], udata_batch)

            # Reset per h-node so a zero-iteration critic schedule cannot
            # report a stale loss from an earlier h-node or batch.
            D_loss = None
            for _ in range(Exp.CRITIC_ITERATIONS):
                D_loss = train_D(Exp, cur_mechs, label_generators, discriminators[hn],
                                 D_optimizers[hn], udata_batch, img_batch)

            d_loss_value = D_loss.numpy() if D_loss is not None else float('nan')

            # NOTE(review): train_dataset is indexed by 'img'/'labels' above;
            # if it is a plain dict, len(train_dataset) is the number of keys,
            # not the number of batches -- confirm the step total and the
            # tot_iter computation below.
            print('Epoch [%d/%d], Step [%d/%d],' % (
                Exp.curr_epoochs + 1, Exp.num_epochs, iteration + 1, len(train_dataset)),
                  'mechanism: ', cur_mechs, ' D_loss: %.4f, G_loss: %.4f' % (d_loss_value, G_loss.numpy()))

        iteration += 1

        tot_iter = Exp.curr_epoochs * len(train_dataset) + iteration
        if tot_iter % 100 == 0:
            Exp.anneal_temperature(tot_iter)

        print('batch size:', udata_batch.shape)

    # if (Exp.curr_epoochs + 1) % 1 == 0:
    #     val1= data_batch['covid_19'].numpy().reshape(-1,1)
    #     # val2= data_batch['pneumonia'].numpy().reshape(-1,1)
    #     # test_batch = np.concatenate([val1, val2], axis=1)
    #     test_batch= val1
    #     csXrayEvaluation(Exp, label_generators, test_batch, tvd_diff, kl_diff)


